import tensorflow as tf
import tree

from keras_core.utils.nest import pack_sequence_as


def rnn(
    step_function,
    inputs,
    initial_states,
    go_backwards=False,
    mask=None,
    constants=None,
    unroll=False,
    input_length=None,
    time_major=False,
    zero_output_for_mask=False,
    return_all_outputs=True,
):
    """Iterates over the time dimension of a tensor.

    Args:
        step_function: RNN step function.
            Args:
                `input`: Tensor with shape `(samples, ...)` (no time dimension),
                    representing input for the batch of samples at a certain
                    time step.
                `states`: List of tensors.
            Returns:
                `output`: Tensor with shape `(samples, output_dim)`
                    (no time dimension).
                `new_states`: List of tensors, same length and shapes
                    as 'states'. The first state in the list must be the
                    output tensor at the previous timestep.
        inputs: Tensor of temporal data of shape `(samples, time, ...)`
            (at least 3D), or nested tensors, and each of which has shape
            `(samples, time, ...)`.
        initial_states: Tensor with shape `(samples, state_size)`
            (no time dimension), containing the initial values for the states
            used in the step function. In the case that state_size is in a
            nested shape, the shape of initial_states will also follow the
            nested structure.
        go_backwards: Boolean. If `True`, do the iteration over the time
            dimension in reverse order and return the reversed sequence.
        mask: Binary tensor with shape `(samples, time, 1)`,
            with a zero for every element that is masked.
        constants: List of constant values passed at each step.
        unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
        input_length: An integer or a 1-D Tensor, depending on whether
            the time dimension is fixed-length or not. In case of variable
            length input, it is used for masking in case there's no mask
            specified. When `None`, every timestep of `inputs` is processed.
        time_major: Boolean. If `True`, the inputs and outputs will be in shape
            `(timesteps, batch, ...)`, whereas in the False case, it will be
            `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
            efficient because it avoids transposes at the beginning and end of
            the RNN calculation. However, most TensorFlow data is batch-major,
            so by default this function accepts input and emits output in
            batch-major form.
        zero_output_for_mask: Boolean. If `True`, the output for masked timestep
            will be zeros, whereas in the `False` case, output from previous
            timestep is returned.
        return_all_outputs: Boolean. If `True`, return the recurrent outputs for
            all timesteps in the sequence. If `False`, only return the output
            for the last timestep (which consumes less memory).

    Returns:
        A tuple, `(last_output, outputs, new_states)`.
            - `last_output`: the latest output of the rnn,
                with shape `(samples, ...)`.
            - `outputs`:
                - If `return_all_outputs=True`: a tensor with shape
                  `(samples, time, ...)` where each entry `outputs[s, t]` is the
                  output of the step function at time `t` for sample `s`
                - Else, a tensor equal to `last_output` with shape
                  `(samples, 1, ...)`
            - `new_states`: list of tensors, latest states returned by
                the step function, of shape `(samples, ...)`.
    """
    # NOTE: `input_length` is deliberately not defaulted from
    # `inputs.shape[1]`. That axis is the batch axis when `time_major=True`,
    # nested `inputs` have no `.shape`, and a Tensor `input_length` cannot be
    # used in a Python `or`. When it is None, the while loop below falls back
    # to the dynamic number of timesteps.

    def swap_batch_timestep(input_t):
        # Swap the batch and timestep dim for the incoming tensor.
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return tf.transpose(input_t, axes)

    # From here on, everything is time-major: (time, batch, ...).
    if not time_major:
        inputs = tree.map_structure(swap_batch_timestep, inputs)

    flattened_inputs = tree.flatten(inputs)
    # Static time length (may be None) and its dynamic counterpart.
    time_steps = flattened_inputs[0].shape[0]
    time_steps_t = tf.shape(flattened_inputs[0])[0]

    for input_ in flattened_inputs:
        input_.shape.with_rank_at_least(3)

    if mask is not None:
        if mask.dtype != tf.bool:
            mask = tf.cast(mask, tf.bool)
        if len(mask.shape) == 2:
            mask = tf.expand_dims(mask, axis=-1)
        if not time_major:
            mask = swap_batch_timestep(mask)

    if constants is None:
        constants = []

    # tf.where needs its condition tensor to be the same shape as its two
    # result tensors, but in our case the condition (mask) tensor is
    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
    # So we need to broadcast the mask to match the shape of inputs.
    # That's what the tile call does, it just repeats the mask along its
    # second dimension n times.
    def _expand_mask(mask_t, input_t, fixed_dim=1):
        if tree.is_nested(mask_t):
            raise ValueError(
                f"mask_t is expected to be tensor, but got {mask_t}"
            )
        if tree.is_nested(input_t):
            raise ValueError(
                f"input_t is expected to be tensor, but got {input_t}"
            )
        rank_diff = len(input_t.shape) - len(mask_t.shape)
        for _ in range(rank_diff):
            mask_t = tf.expand_dims(mask_t, -1)
        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
        return tf.tile(mask_t, multiples)

    if unroll:
        if not time_steps:
            raise ValueError("Unrolling requires a fixed number of timesteps.")
        states = tuple(initial_states)
        successive_states = []
        successive_outputs = []

        # Split every flattened input tensor on the time_step dim, reversing
        # it if go_backwards is True. The result is a list (one entry per
        # flattened input, in the order `pack_sequence_as` expects) of lists
        # of per-step tensors with shape (batch, ...).
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t)  # unstack for time_step dim
            if go_backwards:
                input_t.reverse()
            return input_t

        processed_input = [
            _process_single_input_t(t) for t in flattened_inputs
        ]

        def _get_input_tensor(time):
            inp = [t_[time] for t_ in processed_input]
            return pack_sequence_as(inputs, inp)

        if mask is not None:
            mask_list = tf.unstack(mask)
            if go_backwards:
                mask_list.reverse()

            for i in range(time_steps):
                inp = _get_input_tensor(i)
                mask_t = mask_list[i]
                output, new_states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                tiled_mask_t = _expand_mask(mask_t, output)

                if not successive_outputs:
                    prev_output = tf.zeros_like(output)
                else:
                    prev_output = successive_outputs[-1]

                # Masked steps carry over the previous output.
                output = tf.where(tiled_mask_t, output, prev_output)

                # Masked steps keep the previous states as well.
                flat_states = tree.flatten(states)
                flat_new_states = tree.flatten(new_states)
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, s) for s in flat_states
                )
                flat_final_states = tuple(
                    tf.where(m, s, ps)
                    for m, s, ps in zip(
                        tiled_mask_t, flat_new_states, flat_states
                    )
                )
                states = pack_sequence_as(states, flat_final_states)

                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

            if zero_output_for_mask:
                last_output = tf.where(
                    _expand_mask(mask_list[-1], last_output),
                    last_output,
                    tf.zeros_like(last_output),
                )
                # `outputs` is stacked in processing order (reversed when
                # `go_backwards`) and holds only the final step when
                # `return_all_outputs=False`, so build the matching mask from
                # `mask_list` instead of the raw `mask`.
                if return_all_outputs:
                    outputs_mask = tf.stack(mask_list)
                else:
                    outputs_mask = tf.stack(mask_list[-1:])
                outputs = tf.where(
                    _expand_mask(outputs_mask, outputs, fixed_dim=2),
                    outputs,
                    tf.zeros_like(outputs),
                )

        else:  # mask is None
            for i in range(time_steps):
                inp = _get_input_tensor(i)
                output, states = step_function(
                    inp, tuple(states) + tuple(constants)
                )
                if return_all_outputs:
                    successive_outputs.append(output)
                    successive_states.append(states)
                else:
                    successive_outputs = [output]
                    successive_states = [states]
            last_output = successive_outputs[-1]
            new_states = successive_states[-1]
            outputs = tf.stack(successive_outputs)

    else:  # Unroll == False
        states = tuple(initial_states)

        # Create input tensor array, if the inputs is nested tensors, then it
        # will be flattened first, and tensor array will be created one per
        # flattened tensor.
        input_ta = tuple(
            tf.TensorArray(
                dtype=inp.dtype,
                size=time_steps_t,
                tensor_array_name=f"input_ta_{i}",
            )
            for i, inp in enumerate(flattened_inputs)
        )
        input_ta = tuple(
            ta.unstack(input_)
            if not go_backwards
            else ta.unstack(tf.reverse(input_, [0]))
            for ta, input_ in zip(input_ta, flattened_inputs)
        )

        # Get the time(0) input and compute the output for that, the output will
        # be used to determine the dtype of output tensor array. Don't read from
        # input_ta due to TensorArray clear_after_read default to True.
        input_time_zero = pack_sequence_as(
            inputs, [inp[0] for inp in flattened_inputs]
        )
        # output_time_zero is used to determine the cell output shape and its
        # dtype.  the value is discarded.
        output_time_zero, _ = step_function(
            input_time_zero, tuple(initial_states) + tuple(constants)
        )

        output_ta_size = time_steps_t if return_all_outputs else 1
        output_ta = tuple(
            tf.TensorArray(
                dtype=out.dtype,
                size=output_ta_size,
                element_shape=out.shape,
                tensor_array_name=f"output_ta_{i}",
            )
            for i, out in enumerate(tree.flatten(output_time_zero))
        )

        time = tf.constant(0, dtype="int32", name="time")

        if input_length is None:
            max_iterations = time_steps_t
        else:
            max_iterations = tf.reduce_max(input_length)

        while_loop_kwargs = {
            "cond": lambda time, *_: time < time_steps_t,
            "maximum_iterations": max_iterations,
            "parallel_iterations": 32,
            "swap_memory": True,
        }
        if mask is not None:
            if go_backwards:
                mask = tf.reverse(mask, [0])

            mask_ta = tf.TensorArray(
                dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta"
            )
            mask_ta = mask_ta.unstack(mask)

            def masking_fn(time):
                return mask_ta.read(time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                tiled_mask_t = tuple(
                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))
                    for o in flat_out
                )
                return tuple(
                    tf.where(m, o, fm)
                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)
                )

        elif isinstance(input_length, tf.Tensor):
            # Per-sample lengths: derive the per-step mask from `time`.
            if go_backwards:
                max_len = tf.reduce_max(input_length, axis=0)
                rev_input_length = tf.subtract(max_len - 1, input_length)

                def masking_fn(time):
                    return tf.less(rev_input_length, time)

            else:

                def masking_fn(time):
                    return tf.greater(input_length, time)

            def compute_masked_output(mask_t, flat_out, flat_mask):
                return tuple(
                    tf.where(mask_t, o, zo)
                    for (o, zo) in zip(flat_out, flat_mask)
                )

        else:
            masking_fn = None

        if masking_fn is not None:
            # Mask for the T output will be base on the output of T - 1. In the
            # case T = 0, a zero filled tensor will be used.
            flat_zero_output = tuple(
                tf.zeros_like(o) for o in tree.flatten(output_time_zero)
            )

            def _step(time, output_ta_t, prev_output, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    prev_output: tuple of outputs from time - 1.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                # maybe set shape.
                current_input = pack_sequence_as(inputs, current_input)
                mask_t = masking_fn(time)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                # mask output
                flat_output = tree.flatten(output)
                flat_mask_output = (
                    flat_zero_output
                    if zero_output_for_mask
                    else tree.flatten(prev_output)
                )
                flat_new_output = compute_masked_output(
                    mask_t, flat_output, flat_mask_output
                )

                # mask states
                flat_state = tree.flatten(states)
                flat_new_state = tree.flatten(new_states)
                flat_final_state = compute_masked_output(
                    mask_t, flat_new_state, flat_state
                )
                new_states = pack_sequence_as(new_states, flat_final_state)

                # With return_all_outputs=False, slot 0 is overwritten on
                # every step so only the last output is kept.
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_new_output)
                )

                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(
                    new_states
                )

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta, flat_zero_output) + states,
                **while_loop_kwargs,
            )
            # Skip final_outputs[2] which is the output for final timestep.
            new_states = final_outputs[3:]
        else:

            def _step(time, output_ta_t, *states):
                """RNN step function.

                Args:
                    time: Current timestep value.
                    output_ta_t: TensorArray.
                    *states: List of states.

                Returns:
                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
                """
                current_input = tuple(ta.read(time) for ta in input_ta)
                current_input = pack_sequence_as(inputs, current_input)
                output, new_states = step_function(
                    current_input, tuple(states) + tuple(constants)
                )
                flat_new_state = tree.flatten(new_states)

                flat_output = tree.flatten(output)
                ta_index_to_write = time if return_all_outputs else 0
                output_ta_t = tuple(
                    ta.write(ta_index_to_write, out)
                    for ta, out in zip(output_ta_t, flat_output)
                )

                new_states = pack_sequence_as(initial_states, flat_new_state)
                return (time + 1, output_ta_t) + tuple(new_states)

            final_outputs = tf.while_loop(
                body=_step,
                loop_vars=(time, output_ta) + states,
                **while_loop_kwargs,
            )
            new_states = final_outputs[2:]

        output_ta = final_outputs[1]

        outputs = tuple(o.stack() for o in output_ta)
        last_output = tuple(o[-1] for o in outputs)

        outputs = pack_sequence_as(output_time_zero, outputs)
        last_output = pack_sequence_as(output_time_zero, last_output)

    # Restore the caller's batch-major layout.
    if not time_major:
        outputs = tree.map_structure(swap_batch_timestep, outputs)

    return last_output, outputs, new_states


def gru(
    inputs,
    initial_state,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
    reset_after=True,
):
    """Run a GRU through the cuDNN kernel when the configuration allows it.

    Returns:
        Whatever `_cudnn_gru` returns: `(last_output, outputs, state)`.

    Raises:
        NotImplementedError: if the mask or the cell arguments are not
            cuDNN-compatible, no GPU is visible, or the cuDNN op fails with
            `InvalidArgumentError` / `NotFoundError` (presumably so the caller
            can fall back to a generic implementation).
    """
    mask_ok = _do_rnn_inputs_support_cudnn(mask, time_major)
    args_ok = cudnn_ok(
        activation,
        recurrent_activation,
        unroll,
        use_bias=bias is not None,
        reset_after=reset_after,
    )
    if not (args_ok and mask_ok):
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    def _unwrap(value):
        # Hand cuDNN the underlying tensor rather than the Keras Variable.
        return value.value if isinstance(value, Variable) else value

    kernel = _unwrap(kernel)
    recurrent_kernel = _unwrap(recurrent_kernel)
    bias = _unwrap(bias)

    try:
        return _cudnn_gru(
            inputs,
            initial_state,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except (tf.errors.InvalidArgumentError, tf.errors.NotFoundError):
        # Either the cuDNN op is missing or no device can run it.
        raise NotImplementedError


def _do_gru_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after,
):
    """Check whether GRU hyper-parameters are accepted by the cuDNN path.

    The cuDNN path needs a tanh activation, a sigmoid recurrent activation,
    a rolled (non-unrolled) loop, a bias, and `reset_after`. The result is
    the last evaluated operand of the `and` chain: truthy only when every
    requirement holds.
    """
    from keras_core import activations
    from keras_core import ops

    tanh_variants = (activations.tanh, tf.tanh, ops.tanh)
    sigmoid_variants = (activations.sigmoid, tf.sigmoid, ops.sigmoid)
    return (
        activation in tanh_variants
        and recurrent_activation in sigmoid_variants
        and not unroll
        and use_bias
        and reset_after
    )


def _do_lstm_arguments_support_cudnn(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
):
    """Check whether LSTM hyper-parameters are accepted by the cuDNN path.

    The cuDNN path needs a tanh activation, a sigmoid recurrent activation,
    a rolled (non-unrolled) loop, and a bias. The result is the last
    evaluated operand of the `and` chain: truthy only when every requirement
    holds.
    """
    from keras_core import activations
    from keras_core import ops

    tanh_variants = (activations.tanh, tf.tanh, ops.tanh)
    sigmoid_variants = (activations.sigmoid, tf.sigmoid, ops.sigmoid)
    return (
        activation in tanh_variants
        and recurrent_activation in sigmoid_variants
        and not unroll
        and use_bias
    )


def _do_rnn_inputs_support_cudnn(mask, time_major):
    """Decide whether the input mask can be handled by the cuDNN kernel.

    Without a mask the answer is always `True`. On ROCm builds, the mask
    must be entirely `True`. Otherwise, the (batch-major) mask must be
    strictly right padded and must not contain a fully masked sequence.
    With a mask, the result is a boolean scalar tensor.
    """
    if mask is None:
        return True
    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        return tf.reduce_all(mask)
    batch_major_mask = tf.transpose(mask) if time_major else mask
    right_padded = _is_sequence_right_padded(batch_major_mask)
    no_empty_rows = tf.logical_not(
        _has_fully_masked_sequence(batch_major_mask)
    )
    return tf.logical_and(right_padded, no_empty_rows)


def _is_sequence_right_padded(mask):
    """Check whether the mask tensor is strictly right padded.

    The cuDNN kernel uses the sequence length param to skip the trailing
    timesteps. If the data is left padded, or not strictly right padded (it
    has masked values in the middle of the sequence), the cuDNN kernel will
    not work properly.

    Left padded data: [[False, False, True, True, True]].
    Right padded data: [[True, True, True, False, False]].
    Mixture of mask/unmasked data: [[True, False, True, False, False]].

    Note that for the mixed data example above, the actual data the RNN
    should see are the 2 Trues (index 0 and 2); the False at index 1 should
    be ignored and must not pollute the internal states.

    Args:
      mask: the Boolean tensor with shape [batch, timestep]

    Returns:
      boolean scalar tensor, whether the mask is strictly right padded.
    """
    # Number of unmasked steps per row.
    true_counts = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    # The ideal right-padded mask having the same per-row counts.
    ideal_mask = tf.sequence_mask(true_counts, maxlen=tf.shape(mask)[1])
    same_everywhere = tf.equal(
        tf.cast(mask, dtype="bool"),
        tf.cast(ideal_mask, dtype="bool"),
    )
    return tf.reduce_all(same_everywhere)


def _has_fully_masked_sequence(mask):
    """Return a boolean scalar tensor: does any row of `mask` have no True?

    The cuDNN kernel errors out if the input contains a fully masked
    sequence, so this is used to route such inputs to the standard kernel.
    A row is fully masked exactly when it contains no True value, i.e.
    `all(not m) == not any(m)`.
    """
    row_has_data = tf.reduce_any(tf.cast(mask, dtype="bool"), axis=1)
    return tf.reduce_any(tf.logical_not(row_has_data))


def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):
    """Pack Keras kernels and biases into one cuDNN-compatible param tensor.

    Keras kernels are laid out differently from the cuDNN format, e.g.:

    ```
      Keras                 cuDNN
      [[0, 1, 2],  <--->  [[0, 2, 4],
       [3, 4, 5]]          [1, 3, 5]]
    ```

    Pass `transpose_weights=True` to transpose each kernel into that layout.

    Args:
        weights: list of weights for the kernels and recurrent kernels.
        biases: list of biases for individual gates.
        shape: the shape each piece is reshaped to before concatenation.
        transpose_weights: boolean, whether to transpose the weights.

    Returns:
        All reshaped weights followed by all reshaped biases, concatenated
        along axis 0.
    """
    flat_weights = []
    for w in weights:
        if transpose_weights:
            w = tf.transpose(w)
        flat_weights.append(tf.reshape(w, shape))
    flat_biases = [tf.reshape(b, shape) for b in biases]
    return tf.concat(flat_weights + flat_biases, axis=0)


def _compute_sequence_length_from_mask(mask, time_major):
    """Count the unmasked timesteps of every sequence in `mask`.

    The mask is a 2D boolean tensor in which masked timesteps are False.
    For example:
      a = [[True, True, False, False],
           [True, True, True, False]]
    is a (2, 4) batch-major mask whose sequence lengths are [2, 3]. The mask
    must be right padded (see `_is_sequence_right_padded`) for these counts
    to be meaningful sequence lengths.

    Args:
      mask: Boolean tensor with shape [batch, timestep], or
        [timestep, batch] if time_major=True.
      time_major: Boolean, whether the mask is time major or batch major.

    Returns:
      sequence_length: 1D int32 tensor.
    """
    if time_major:
        time_axis = 0
    else:
        time_axis = 1
    step_flags = tf.cast(mask, tf.int32)
    return tf.reduce_sum(step_flags, axis=time_axis)


def _is_gpu_available():
    """Return True when TensorFlow sees at least one logical GPU device."""
    gpu_devices = tf.config.list_logical_devices("GPU")
    return len(gpu_devices) > 0


@tf.function(autograph=False)
def _cudnn_gru(
    inputs,
    initial_state,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """GRU with cuDNN implementation which is only available for GPU.

    Args:
        inputs: 3D input tensor; time-major if `time_major`, else batch-major.
        initial_state: Initial hidden state (no time dimension).
        kernel: Input kernel, split into 3 gate blocks along axis 1.
        recurrent_kernel: Recurrent kernel, split into 3 gate blocks along
            axis 1.
        bias: Bias tensor, flattened and split into 6 equal chunks.
        mask: Optional boolean mask used to derive per-sample sequence
            lengths (presumably right padded; see
            `_do_rnn_inputs_support_cudnn`).
        time_major: Whether `inputs` and `mask` are time-major.
        go_backwards: Whether to process the sequence in reverse.
        return_sequences: Whether `outputs` holds every timestep.

    Returns:
        Tuple `(last_output, outputs, state)`. If `return_sequences` is False,
        `outputs` is `last_output` with a length-1 time axis inserted.
    """
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    # Without sequence lengths, CudnnRNN is called below without a
    # `time_major` argument, so batch-major inputs are transposed to
    # time-major here. CudnnRNNV3 accepts `time_major` directly.
    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more dim of num_layers before or after batch
    # dim for time major or batch major inputs respectively
    init_h = tf.expand_dims(initial_state, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units), flatten it to
    # (6 * units)
    bias = tf.split(tf.reshape(bias, [-1]), 6)

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN is different from the canonical
        # format.  canonical format is [z, r, h], whereas cuDNN is [r, z, h].
        # The swap need to be done for kernel, recurrent_kernel, input_bias,
        # recurrent_bias.
        # z is update gate weights.
        # r is reset gate weights.
        # h is output gate weights.
        # NOTE(review): non-CUDA (e.g. ROCm) builds skip the swap, presumably
        # because that backend expects the canonical order — confirm.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = _standardize_cudnn_weights(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5 ,4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        # GRU has no cell state; input_c is a placeholder 0.
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already convert to time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    # outputs[-1] is the last timestep only when `outputs` is time-major; in
    # the batch-major masked case it is replaced by the final state below.
    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    state = tf.squeeze(h, axis=seq_axis)

    # In the case of variable length input, the cudnn kernel will fill zeros for
    # the output, whereas the default keras behavior is to bring over the
    # previous output for t-1, so that in the return_sequence=False case, user
    # can quickly get the final effect output instead just 0s at the last
    # timestep.  In order to mimic the default keras behavior, we copy the final
    # h state as the last_output, since it is numerically same as the output.
    if sequence_lengths is not None:
        last_output = state

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        state,
    )


def cudnn_ok(
    activation,
    recurrent_activation,
    unroll,
    use_bias,
    reset_after=None,
):
    """Return whether a recurrent cell configuration can use cuDNN.

    `reset_after=None` applies the LSTM compatibility rules; any other value
    applies the GRU rules, which additionally test `reset_after`. A GPU must
    also be visible for the result to be truthy.
    """
    common_args = dict(
        activation=activation,
        recurrent_activation=recurrent_activation,
        unroll=unroll,
        use_bias=use_bias,
    )
    if reset_after is not None:
        args_supported = _do_gru_arguments_support_cudnn(
            reset_after=reset_after, **common_args
        )
    else:
        args_supported = _do_lstm_arguments_support_cudnn(**common_args)
    return args_supported and _is_gpu_available()


def lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    mask,
    kernel,
    recurrent_kernel,
    bias,
    activation,
    recurrent_activation,
    return_sequences=False,
    go_backwards=False,
    unroll=False,
    time_major=False,
):
    """Run an LSTM through the cuDNN kernel when the configuration allows it.

    Returns:
        Whatever `_cudnn_lstm` returns: `(last_output, outputs, [h, c])`.

    Raises:
        NotImplementedError: if the mask or the cell arguments are not
            cuDNN-compatible, no GPU is visible, or the cuDNN op fails with
            `InvalidArgumentError` / `NotFoundError` (presumably so the caller
            can fall back to a generic implementation).
    """
    mask_ok = _do_rnn_inputs_support_cudnn(mask, time_major)
    args_ok = cudnn_ok(
        activation, recurrent_activation, unroll, use_bias=bias is not None
    )
    if not (args_ok and mask_ok):
        raise NotImplementedError

    from keras_core.backend.tensorflow import Variable

    def _unwrap(value):
        # Hand cuDNN the underlying tensor rather than the Keras Variable.
        return value.value if isinstance(value, Variable) else value

    kernel = _unwrap(kernel)
    recurrent_kernel = _unwrap(recurrent_kernel)
    bias = _unwrap(bias)

    try:
        return _cudnn_lstm(
            inputs,
            initial_state_h,
            initial_state_c,
            kernel,
            recurrent_kernel,
            bias,
            mask,
            time_major,
            go_backwards,
            return_sequences,
        )
    except (tf.errors.InvalidArgumentError, tf.errors.NotFoundError):
        # Either the cuDNN op is missing or no device can run it.
        raise NotImplementedError


def _cudnn_lstm(
    inputs,
    initial_state_h,
    initial_state_c,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    return_sequences,
):
    """LSTM with cuDNN implementation (GPU only).

    Args:
        inputs: 3D input tensor; time-major if `time_major`, else batch-major.
        initial_state_h: Initial hidden state (no time dimension).
        initial_state_c: Initial cell state (no time dimension).
        kernel: Input kernel, split into 4 gate blocks along axis 1.
        recurrent_kernel: Recurrent kernel, split into 4 gate blocks along
            axis 1.
        bias: Bias tensor; used as cuDNN's recurrent bias, with a zero input
            bias prepended.
        mask: Optional boolean mask used to derive per-sample sequence
            lengths (presumably right padded; see
            `_do_rnn_inputs_support_cudnn`).
        time_major: Whether `inputs` and `mask` are time-major.
        go_backwards: Whether to process the sequence in reverse.
        return_sequences: Whether `outputs` holds every timestep.

    Returns:
        Tuple `(last_output, outputs, [h, c])`. If `return_sequences` is False,
        `outputs` is `last_output` with a length-1 time axis inserted.
    """
    if mask is not None:
        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
    else:
        sequence_lengths = None

    # Without sequence lengths, CudnnRNN is called below without a
    # `time_major` argument, so batch-major inputs are transposed to
    # time-major here. CudnnRNNV3 accepts `time_major` directly.
    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h and init_c, cuDNN expects one more dim of num_layers before or
    # after batch dim for time major or batch major inputs respectively
    init_h = tf.expand_dims(initial_state_h, axis=seq_axis)
    init_c = tf.expand_dims(initial_state_c, axis=seq_axis)

    weights = tf.split(kernel, 4, axis=1)
    weights += tf.split(recurrent_kernel, 4, axis=1)
    # cuDNN has an extra set of bias for inputs, we disable them (setting to 0),
    # so that mathematically it is same as the canonical LSTM implementation.
    full_bias = tf.concat((tf.zeros_like(bias), bias), 0)

    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        # ROCm MIOpen's weight sequence for LSTM is different from both
        # canonical and Cudnn format
        # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
        # i is input gate weights.
        # f is forget gate weights.
        # o is output gate weights.
        # c is cell gate weights.
        weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
        # full_bias is a tensor of shape (8*n,)
        full_bias = tf.split(full_bias, 8, axis=0)
        full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]

    # On ROCm `full_bias` is a list of 8 chunks here; tf.split packs it into
    # an (8, n) tensor and splits it back into 8 pieces, each of which is
    # flattened by _standardize_cudnn_weights.
    params = _standardize_cudnn_weights(
        weights=weights,
        biases=tf.split(full_bias, 8),
        shape=tf.constant([-1]),
        transpose_weights=True,
    )

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5 ,4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        # No mask: every sequence runs for the full number of timesteps, so
        # the plain CudnnRNN op (no sequence_lengths) is used.
        if go_backwards:
            # Reverse axis 0 since the input is already convert to time major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, c, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=init_c,
            params=params,
            is_training=True,
            rnn_mode="lstm",
        )

    # outputs[-1] is the last timestep only when `outputs` is time-major; in
    # the batch-major masked case it is replaced by `h` below.
    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    h = tf.squeeze(h, axis=seq_axis)
    c = tf.squeeze(c, axis=seq_axis)

    # In the case of variable length input, the cudnn kernel will fill zeros for
    # the output, whereas the default keras behavior is to bring over the
    # previous output for t-1, so that in the return_sequence=False case, user
    # can quickly get the final effect output instead just 0s at the last
    # timestep.  In order to mimic the default keras behavior, we copy the final
    # h state as the last_output, since it is numerically same as the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (last_output, outputs, [h, c])
